{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6d00ba42-015f-4ec9-9848-4e31026285e4",
   "metadata": {},
   "source": [
    "Reference: [scikit-learn user guide: K-means clustering](https://scikit-learn.org/1.1/modules/clustering.html#k-means)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d15b814-d6f1-4965-ab6b-e4436999ee27",
   "metadata": {},
   "source": [
    "## Training and Logging"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "4f68bf31-708f-4fc8-88e5-85832fb0bbee",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((150, 4), (150,))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sklearn.datasets import load_iris\n",
    "\n",
    "# Load the iris feature matrix X and label vector y, then display their shapes.\n",
    "X, y = load_iris(return_X_y=True)\n",
    "X.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "cee3c8b7-961d-41dd-bbbd-23b30f848098",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(0, 1, 2)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Labels stored at rows 0, 50 and 100 of y.\n",
    "y[0], y[50], y[100]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "d626f3d8-ff3a-4a6c-988f-a7898fe5521f",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/da/.cache/pants/named_caches/pex_root/venvs/s/378f4125/venv/lib/python3.9/site-packages/sklearn/cluster/_kmeans.py:870: FutureWarning: The default value of `n_init` will change from 10 to 'auto' in 1.4. Set the value of `n_init` explicitly to suppress the warning\n",
      "  warnings.warn(\n",
      "/home/da/.cache/pants/named_caches/pex_root/venvs/s/378f4125/venv/lib/python3.9/site-packages/_distutils_hack/__init__.py:33: UserWarning: Setuptools is replacing distutils.\n",
      "  warnings.warn(\"Setuptools is replacing distutils.\")\n",
      "2023/03/29 13:14:00 WARNING mlflow.utils.environment: Failed to resolve installed pip version. ``pip`` will be added to conda.yaml environment spec without a version specifier.\n",
      "Registered model 'da_kmeans' already exists. Creating a new version of this model...\n",
      "2023/03/29 13:14:00 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: da_kmeans, version 2\n",
      "Created version '2' of model 'da_kmeans'.\n",
      "/home/da/.cache/pants/named_caches/pex_root/venvs/s/378f4125/venv/lib/python3.9/site-packages/liga/mlflow/logger.py:137: UserWarning: value of rikai.output.schema is None or empty and will not be populated to MLflow\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "import getpass\n",
    "\n",
    "import mlflow\n",
    "from liga.sklearn.mlflow import log_model\n",
    "from sklearn.cluster import KMeans\n",
    "\n",
    "# SQLite-backed MLflow tracking store, kept in a variable so the same URI can be reused.\n",
    "mlflow_tracking_uri = \"sqlite:///mlruns.db\"\n",
    "mlflow.set_tracking_uri(mlflow_tracking_uri)\n",
    "\n",
    "# Train the model and register it on MLflow inside a single run.\n",
    "with mlflow.start_run() as run:\n",
    "    # n_init is pinned to 10 (the current default): scikit-learn warns that the default\n",
    "    # changes to 'auto' in 1.4, so setting it explicitly keeps results stable across\n",
    "    # versions and silences the FutureWarning.\n",
    "    model = KMeans(n_clusters=3, n_init=10, random_state=0)\n",
    "    model.fit(X)\n",
    "    # The registered model name is prefixed with the current OS user name.\n",
    "    registered_model_name = f\"{getpass.getuser()}_kmeans\"\n",
    "    log_model(model, registered_model_name=registered_model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a898c9e-652e-4e0f-b4dc-5cc33444f18f",
   "metadata": {},
   "source": [
    "## Apply the model to a large-scale dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e3d3fdb6-1b67-4c2b-ad32-4f7164460eb9",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2023-03-29 13:14:00,268 INFO Rikai (__init__.py:121): setting spark.sql.extensions to net.xmacs.liga.spark.RikaiSparkSessionExtensions\n",
      "2023-03-29 13:14:00,268 INFO Rikai (__init__.py:121): setting spark.driver.extraJavaOptions to -Dio.netty.tryReflectionSetAccessible=true\n",
      "2023-03-29 13:14:00,269 INFO Rikai (__init__.py:121): setting spark.executor.extraJavaOptions to -Dio.netty.tryReflectionSetAccessible=true\n",
      "2023-03-29 13:14:00,270 INFO Rikai (__init__.py:121): setting spark.jars to https://github.com/komprenilo/liga/releases/download/v0.3.0/liga-spark321-assembly_2.12-0.3.0.jar\n",
      "23/03/29 13:14:01 WARN Utils: Your hostname, tubi resolves to a loopback address: 127.0.1.1; using 192.168.31.32 instead (on interface wlp0s20f3)\n",
      "23/03/29 13:14:01 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address\n",
      "Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties\n",
      "Setting default log level to \"WARN\".\n",
      "To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).\n",
      "23/03/29 13:14:06 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable\n",
      "23/03/29 13:14:06 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "+----------------+------+-------------------+-------+\n",
      "|name            |plugin|uri                |options|\n",
      "+----------------+------+-------------------+-------+\n",
      "|mlflow_sklearn_m|      |mlflow:///da_kmeans|       |\n",
      "+----------------+------+-------------------+-------+\n",
      "\n"
     ]
    }
   ],
   "source": [
    "from example import spark\n",
    "from liga.mlflow import CONF_MLFLOW_TRACKING_URI\n",
    "\n",
    "# Session settings: turn off the Arrow-based PySpark path and hand Spark the MLflow tracking URI.\n",
    "spark.conf.set(\"spark.sql.execution.arrow.pyspark.enabled\", \"false\")\n",
    "spark.conf.set(CONF_MLFLOW_TRACKING_URI, mlflow_tracking_uri)\n",
    "\n",
    "# Bind the MLflow-registered model to the SQL model name mlflow_sklearn_m.\n",
    "create_model_sql = f\"\"\"\n",
    "CREATE OR REPLACE MODEL mlflow_sklearn_m LOCATION 'mlflow:///{registered_model_name}';\n",
    "\"\"\"\n",
    "spark.sql(create_model_sql)\n",
    "\n",
    "# Show the first entry of the model catalog without truncating column values.\n",
    "spark.sql(\"show models\").show(n=1, truncate=False, vertical=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "89f96830-40d5-4a2c-8e0c-adbaa6e30c4e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "root\n",
      " |-- y0: integer (nullable = true)\n",
      " |-- y50: integer (nullable = true)\n",
      " |-- y100: integer (nullable = true)\n",
      "\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                                                \r"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>y0</th>\n",
       "      <th>y50</th>\n",
       "      <th>y100</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>2</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   y0  y50  y100\n",
       "0   1    0     2"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from liga.numpy.sql import literal\n",
    "\n",
    "# Score rows 0, 50 and 100 of X; each row is embedded in the query text via literal().\n",
    "prediction_sql = f\"\"\"\n",
    "select\n",
    "  ML_PREDICT(mlflow_sklearn_m, {literal(X[0])}) as y0,\n",
    "  ML_PREDICT(mlflow_sklearn_m, {literal(X[50])}) as y50,\n",
    "  ML_PREDICT(mlflow_sklearn_m, {literal(X[100])}) as y100\n",
    "\"\"\"\n",
    "result = spark.sql(prediction_sql)\n",
    "\n",
    "# Print the inferred column types, then display the single result row as a pandas frame.\n",
    "result.printSchema()\n",
    "result.toPandas()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38e256be-3299-4ac5-8bdd-b8cfd1e8e5e9",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
